# https://github.com/EleMisi/ConditionalVAE/blob/master/notebooks/Train_ConditionalVAE.ipynb

# Import all the necessary libraries

import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import Sequential, Model
from keras.layers import Dense, Conv2D, MaxPooling2D,MaxPool2D ,UpSampling2D, Flatten, Input
from keras.optimizers import SGD, Adam, Adadelta, Adagrad
from keras import backend as K
import tensorflow as tf
from tqdm import tqdm

from tfFunctionsUtils import load_dataset
##
import tensorflow as tf
from ConvolutionalCondVAE import ConvCVAE, Decoder, Encoder
#
import os
#######################
# Train Step Function #
#######################


def train_step(data, model, optimizer):
    """Run one optimization step and report the resulting losses.

    The model is called on ``data`` with ``is_train=True`` under a gradient
    tape; the gradient of its ``'loss'`` output with respect to
    ``model.trainable_variables`` is then applied through ``optimizer``.

    Args:
        data: Input forwarded unchanged to ``model``. In this script it is
            the list ``[img_batch, label_batch]``.
        model: Callable exposing ``trainable_variables`` and returning a
            mapping with the keys ``'loss'``, ``'reconstr_loss'``,
            ``'latent_loss'`` and ``'recon_img'``.
        optimizer: Object providing ``apply_gradients``.

    Returns:
        A tuple ``(total_loss, recon_loss, latent_loss, recon_img)``. The
        three losses are the ``.numpy().mean()`` of the matching model
        outputs; ``recon_img`` is the model's ``'recon_img'`` output as-is
        (presumably the reconstructed images -- confirm against the model).
    """
    # Only the forward pass has to be recorded by the tape.
    with tf.GradientTape() as tape:
        outputs = model(data, is_train = True)

    variables = model.trainable_variables
    gradients = tape.gradient(outputs['loss'], variables)
    optimizer.apply_gradients(zip(gradients, variables))

    # Collapse each loss output to a single Python-side number for logging.
    total_loss, recon_loss, latent_loss = (
        outputs[key].numpy().mean()
        for key in ('loss', 'reconstr_loss', 'latent_loss')
    )

    return total_loss, recon_loss, latent_loss, outputs['recon_img']


def _save_reconstruction_grid(originals, reconstructions, out_path, n=6):
    """Save a 2 x n figure: inputs on the top row, reconstructions below.

    Args:
        originals: Indexable batch of input images (sliced with ``[0:n]``).
        reconstructions: Indexable batch of reconstructed images.
        out_path: Destination image file; its parent directory is created
            when missing.
        n: Number of columns, i.e. the maximum number of image pairs drawn.
    """
    # savefig does not create directories, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)

    fig, axs = plt.subplots(2, n, figsize=(8, 3))
    try:
        for j, (img1, img2) in enumerate(zip(originals[0:n], reconstructions[0:n])):
            axs[0][j].imshow(img1)
            axs[0][j].axis('off')

            axs[1][j].imshow(img2)
            axs[1][j].axis('off')

        fig.savefig(out_path)
    finally:
        # Figures stay registered with pyplot until closed; release each one
        # so periodic plotting does not accumulate memory.
        plt.close(fig)


# Script entry point: load the training data, build the conditional VAE,
# resume from the latest checkpoint if one exists, train for n_epochs, and
# periodically save checkpoints plus a grid of reconstructions.
if __name__ == '__main__':
    import time  # not imported at module level

    # Training configuration
    learning_rate = 0.001
    batch_size = 32
    n_epochs = 30

    # Hyper-parameters
    label_dim = 1
    image_size = 64
    image_dim = [image_size, image_size, 3]
    latent_dim = 128
    beta = 0.65

    root = "/Dataset/COVIDx-splitted-resized-112"
    data = pd.read_csv(f'{root}/train_dataset.csv')

    # Image data. load_dataset also returns the row positions it kept, which
    # are used below to select the matching labels.
    # NOTE(review): assumes image_data is already batched with batch_size and
    # iterates in the same order as valid_id -- confirm in tfFunctionsUtils.
    image_data, valid_id = load_dataset(batch_size, image_size, root, data, split='train')

    # Label data, batched the same way so it can be zipped with the images.
    label_data = data[["covid_19"]].iloc[valid_id]
    label_dataset = tf.data.Dataset.from_tensor_slices(label_data).batch(batch_size)

    train_dataset = {'img': image_data, 'labels': label_dataset}

    # Model
    encoder = Encoder(latent_dim)
    decoder = Decoder()
    model = ConvCVAE(
        encoder,
        decoder,
        label_dim=label_dim,
        latent_dim=latent_dim,
        beta=beta,
        image_dim=image_dim)

    # Optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Checkpoint path
    checkpoint_root = "./CVAE{}_{}_checkpoint".format(latent_dim, beta)
    checkpoint_name = "model"
    save_prefix = os.path.join(checkpoint_root, checkpoint_name)

    # Define the checkpoint
    checkpoint = tf.train.Checkpoint(module=model)

    # Restore the latest checkpoint, if any
    latest = tf.train.latest_checkpoint(checkpoint_root)
    if latest is not None:
        checkpoint.restore(latest)
        print("Checkpoint restored:", latest)
    else:
        print("No checkpoint!")

    # Per-epoch history of the averaged losses.
    loss = []
    reconstruct_loss = []
    latent_loss = []

    n_batches = len(image_data)

    print("Number of epochs: {},  number of batches: {}".format(n_epochs, n_batches))

    # Epochs Loop
    for epoch in range(n_epochs):
        start_time = time.perf_counter()

        # Fresh accumulators every epoch, so the reported values are the mean
        # over this epoch only rather than over all epochs so far.
        train_losses = []
        train_recon_errors = []
        train_latent_losses = []

        # Train Step Loop. tqdm wraps the batch iterator directly and is told
        # the total so it can render an actual progress bar.
        batches = zip(train_dataset['img'], train_dataset['labels'])
        for step_index, (img_batch, label_batch) in enumerate(tqdm(batches, total=n_batches)):

            # Todo: Normalizing images. Check if it harms
            img_batch = tf.image.per_image_standardization(img_batch)

            total_loss, recon_loss, lat_loss, recon_img = train_step([img_batch, label_batch], model, optimizer)
            train_losses.append(total_loss)
            train_recon_errors.append(recon_loss)
            train_latent_losses.append(lat_loss)

            if step_index + 1 == n_batches:
                break

        if not train_losses:
            # Without any batch there is nothing to average or to plot.
            raise RuntimeError("No training batches were produced; check the dataset at " + root)

        loss.append(np.mean(train_losses, 0))
        reconstruct_loss.append(np.mean(train_recon_errors, 0))
        latent_loss.append(np.mean(train_latent_losses, 0))

        exec_time = time.perf_counter() - start_time
        print("Execution time: %0.3f \t Epoch %i: loss %0.4f | reconstr loss %0.4f | latent loss %0.4f"
              % (exec_time, epoch, loss[epoch], reconstruct_loss[epoch], latent_loss[epoch]))

        # Save progress every 5 epochs
        if (epoch + 1) % 5 == 0:
            # save() returns the path actually written (prefix plus a counter).
            saved_path = checkpoint.save(save_prefix + "_" + str(epoch + 1))
            print("Model saved:", saved_path)

            # Plot the last batch of the epoch next to its reconstruction.
            _save_reconstruction_grid(img_batch, recon_img, f'./recon_images/{epoch}.png')

    # Save the final model
    saved_path = checkpoint.save(save_prefix)
    print("Model saved:", saved_path)



